class Solution {
public:
    int countNodes(TreeNode* root) {
        if (!root) return 0;
        int lh = 0, rh = 0;
        TreeNode* left = root->left;
        TreeNode* right = root->right;
        while (left)
        {
            ++lh;
            left = left->left;
        }
        while (right)
        {
            ++rh;
            right = right->right;
        }
        if (lh == rh) return (2 << lh) - 1;
        return countNodes(root->left) + countNodes(root->right) + 1;
    }
};